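/*
CNN test driver.
Intended to exercise the convolutional neural network end to end: it loads the
JSON configuration, trains the network briefly on randomly generated data, and
then runs prediction.
Note: this is currently a smoke test only. It checks that the pipeline runs
without failing, but it makes no assertions about per-layer outputs, overall
performance, or accuracy (no separate unit tests are contained here).
*/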
#include <cstdlib>
#include <exception>
#include <iostream>
#include <string>

#include "jsonParser.h"
#include "ConvNeuralNetwork.h"

/**
 * Smoke-test driver for ConvNeuralNetwork.
 *
 * Steps:
 *   1. Reads the JSON configuration file and passes the parsed object to
 *      visitModel / visitLayers / visitTraining / visitData (presumably to
 *      dump or validate each config section — TODO confirm).
 *   2. Constructs the network from the same configuration file.
 *   3. Runs train() on a single randomly filled input/label pair.
 *   4. Runs predict() on that input.
 *
 * Usage: <program> [configFile]
 *   configFile  Optional path to the JSON configuration. Defaults to
 *               "../config/config.json", resolved relative to the current
 *               working directory, which matches the previous hard-coded path.
 *
 * @return EXIT_SUCCESS (0) when every step completes. Returns EXIT_FAILURE
 *         if any step throws a std::exception; the message goes to stderr.
 */
int main(int argc, char** argv) {
    // Let the caller override the config path. With no arguments the
    // behaviour is identical to the original hard-coded path.
    const std::string configFile = (argc > 1 && argv[1] != nullptr)
                                       ? std::string(argv[1])
                                       : std::string("../config/config.json");

    try {
        json config;
        readJson(configFile, config);
        visitModel(config);
        visitLayers(config);
        visitTraining(config);
        visitData(config);

        ConvNeuralNetwork cnn(configFile);

        // One random sample of shape (1, 28, 28). Presumably channel-first,
        // MNIST-sized input — TODO confirm against the configured layers.
        Tensor<double, 3> input(1, 28, 28);
        input.setRandom();

        // Random 10-element target. It is not a valid one-hot/probability
        // vector; it only exercises the training code path.
        Tensor<double, 1> label(10);
        label.setRandom();

        // Pre-filled so the buffer holds defined values even if predict()
        // does not overwrite every element.
        Tensor<double, 1> output(10);
        output.setRandom();

        int batchSize = 2;
        int epochs = 8;
        double learningRate = 0.8;
        double momentum = 0.9;
        cnn.train(input, label, batchSize, epochs, learningRate, momentum);
        cnn.predict(input, output);
    } catch (const std::exception& e) {
        // Without this, an exception (e.g. a malformed or missing config file)
        // would escape main and call std::terminate with no useful diagnostic.
        std::cerr << "CNN test failed: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}